"""
Phase 1: Download bilateral capital flow data and construct gravity panel.

Downloads:
  1a. CPIS (Coordinated Portfolio Investment Survey) from IMF PIP database
  1b. CDIS (Coordinated Direct Investment Survey) from IMF DIP database
  1c. CEPII GeoDist gravity variables
  1d. Merge all with demographics from followup panel

Output: gravity_bilateral/data/processed/bilateral_panel.csv
"""

import pandas as pd
import numpy as np
from pathlib import Path
import time
import requests
import io

# Filesystem layout: both project trees hang off one shared root.
_PROJECT_ROOT = Path("/mnt/c/demographics_capital_flows")
BASE_DIR = _PROJECT_ROOT / "gravity_bilateral"
FOLLOWUP_DIR = _PROJECT_ROOT / "multilateral" / "followup"

RAW_DIR = BASE_DIR / "data" / "raw"
PROCESSED_DIR = BASE_DIR / "data" / "processed"
OUTPUT_DIR = BASE_DIR / "output" / "tables"

# Create every directory this script writes to (no-op when already present).
for _target_dir in (RAW_DIR, PROCESSED_DIR, OUTPUT_DIR):
    _target_dir.mkdir(parents=True, exist_ok=True)


# -----------------------------------------------------------------------
# 1a. Download CPIS bilateral portfolio investment positions
# -----------------------------------------------------------------------

def download_cpis():
    """Download CPIS bilateral portfolio positions from the IMF PIP database.

    If ``RAW_DIR / "cpis_bilateral.csv"`` already exists it is loaded and
    returned without any network access. Otherwise the three portfolio
    indicators (total, equity, debt) are requested through ``imfp`` in
    chunks of reporter countries, concatenated, normalised to a long
    format, filtered to country-to-country pairs and written to the cache.

    Returns:
        pandas.DataFrame: One of
            * the cleaned long-format frame with columns ``reporter``,
              ``partner``, ``year``, ``value`` and ``indicator_label``;
            * an empty frame when no chunk returned data (nothing cached);
            * the raw concatenated download when any expected column could
              not be identified (a warning is printed in that case).
    """
    cache = RAW_DIR / "cpis_bilateral.csv"
    if cache.exists():
        print(f"  CPIS cached: {cache}")
        df = pd.read_csv(cache)
        print(f"  Loaded {len(df):,} rows")
        return df

    # Imported after the cache check, so cached runs never import imfp.
    import imfp

    print("  Downloading CPIS bilateral data from IMF PIP database...")

    # Code prefixes treated as aggregates rather than individual countries.
    aggregate_prefixes = ('TX', 'GX', 'W0', 'W1')

    # Get country list
    params = imfp.imf_parameters('PIP')
    all_countries = params['country']['input_code'].tolist()
    # Filter to actual countries (3-letter codes, not aggregates)
    countries = [c for c in all_countries
                 if len(c) == 3 and not c.startswith(aggregate_prefixes)]
    print(f"  {len(countries)} reporter countries")

    # Indicators: total portfolio, equity, debt
    indicator_labels = {
        'P_TOTINV_P_USD': 'portfolio_total',
        'P_F51_P_USD': 'portfolio_equity',
        'P_F3_P_USD': 'portfolio_debt',
    }

    all_data = []
    failed_chunks = 0
    chunk_size = 15  # Smaller chunks for bilateral data
    n_chunks = (len(countries) + chunk_size - 1) // chunk_size

    for ind, label in indicator_labels.items():
        print(f"\n  Indicator: {ind} ({label})")
        for i in range(0, len(countries), chunk_size):
            chunk = countries[i:i + chunk_size]
            print(f"    Chunk {i // chunk_size + 1}/{n_chunks}: "
                  f"{chunk[0]}-{chunk[-1]}...", end=" ")
            try:
                data = imfp.imf_dataset(
                    database_id='PIP',
                    indicator=ind,
                    country=chunk,
                    accounting_entry='A',  # Assets = outward holdings
                    frequency='A',
                    start_year=2001,
                    end_year=2024,
                )
                if data is not None and len(data) > 0:
                    data['indicator_label'] = label
                    all_data.append(data)
                    print(f"{len(data):,} rows")
                else:
                    print("no data")
            except Exception as e:
                # Best effort: one failing chunk must not abort the download.
                failed_chunks += 1
                print(f"error: {e}")
            time.sleep(1.5)  # Rate limiting

    if not all_data:
        print("  WARNING: No CPIS data downloaded!")
        return pd.DataFrame()

    if failed_chunks:
        # The result below is cached, so make the partial coverage visible.
        print(f"  WARNING: {failed_chunks} chunk request(s) failed; "
              f"cached CPIS data will be incomplete.")

    df = pd.concat(all_data, ignore_index=True)
    print(f"\n  Raw CPIS: {len(df):,} rows")

    # Clean columns: map the known source names (case-insensitive) onto
    # the names used by the rest of the pipeline.
    aliases = {
        'ref_area': 'reporter', 'country': 'reporter',
        'counterpart_area': 'partner', 'counterpart_country': 'partner',
        'time_period': 'year',
        'obs_value': 'value', 'value': 'value',
    }
    cols_map = {c: aliases[c.lower()] for c in df.columns
                if c.lower() in aliases}
    df = df.rename(columns=cols_map)

    # Every required column is used below, so any missing one is fatal for
    # the cleaning step (the previous "< 4" threshold let a single missing
    # column through and then failed with a KeyError).
    required = ['reporter', 'partner', 'year', 'value', 'indicator_label']
    missing = [c for c in required if c not in df.columns]
    if missing:
        print(f"  WARNING: Missing columns. Available: {list(df.columns)}")
        print(f"  WARNING: Could not identify: {missing}")
        # Try to work with what we have.
        # NOTE: the unprocessed frame is written to the cache path, so later
        # runs will load it as-is until the file is deleted.
        df.to_csv(cache, index=False)
        return df

    df = df[required].copy()
    df['year'] = pd.to_numeric(df['year'], errors='coerce')
    df['value'] = pd.to_numeric(df['value'], errors='coerce')

    # Remove aggregate partners (keep only country-to-country). The length
    # filter runs first so that missing partner codes are gone before the
    # prefix mask is negated.
    df = df[df['partner'].str.len() == 3].copy()
    df = df[~df['partner'].str.startswith(aggregate_prefixes)].copy()
    # Remove self-investments
    df = df[df['reporter'] != df['partner']].copy()

    df.to_csv(cache, index=False)
    print(f"  Saved CPIS: {len(df):,} rows to {cache}")
    return df


# -----------------------------------------------------------------------
# 1b. Download CDIS bilateral direct investment positions
# -----------------------------------------------------------------------

def download_cdis():
    """Download CDIS bilateral direct investment positions from the IMF DIP database.

    If ``RAW_DIR / "cdis_bilateral.csv"`` already exists it is loaded and
    returned without any network access. Otherwise the outward direct
    investment indicator is requested through ``imfp`` in chunks of reporter
    countries, concatenated, normalised to a long format, filtered to
    country-to-country pairs and written to the cache.

    Returns:
        pandas.DataFrame: Long-format frame with columns ``reporter``,
        ``partner``, ``year``, ``value`` and ``indicator_label`` (always
        ``'fdi_outward'``). An empty frame is returned, and nothing is
        written to the cache, when no chunk returned data or when the
        expected columns cannot be identified in the download.
    """
    cache = RAW_DIR / "cdis_bilateral.csv"
    if cache.exists():
        print(f"  CDIS cached: {cache}")
        df = pd.read_csv(cache)
        print(f"  Loaded {len(df):,} rows")
        return df

    # Imported after the cache check, so cached runs never import imfp.
    import imfp

    print("  Downloading CDIS bilateral data from IMF DIP database...")

    # Code prefixes treated as aggregates rather than individual countries.
    aggregate_prefixes = ('TX', 'GX', 'W0', 'W1')

    params = imfp.imf_parameters('DIP')
    all_countries = params['country']['input_code'].tolist()
    countries = [c for c in all_countries
                 if len(c) == 3 and not c.startswith(aggregate_prefixes)]
    print(f"  {len(countries)} reporter countries")

    # Outward DI: net (assets less liabilities), all instruments, all entities
    indicator = 'OTWD_D_NETAL_FALL_ALL'

    all_data = []
    failed_chunks = 0
    chunk_size = 15
    n_chunks = (len(countries) + chunk_size - 1) // chunk_size

    for i in range(0, len(countries), chunk_size):
        chunk = countries[i:i + chunk_size]
        print(f"    Chunk {i // chunk_size + 1}/{n_chunks}: "
              f"{chunk[0]}-{chunk[-1]}...", end=" ")
        try:
            data = imfp.imf_dataset(
                database_id='DIP',
                indicator=indicator,
                country=chunk,
                frequency='A',
                start_year=2009,
                end_year=2024,
            )
            if data is not None and len(data) > 0:
                all_data.append(data)
                print(f"{len(data):,} rows")
            else:
                print("no data")
        except Exception as e:
            # Best effort: one failing chunk must not abort the download.
            failed_chunks += 1
            print(f"error: {e}")
        time.sleep(1.5)  # Rate limiting

    if not all_data:
        print("  WARNING: No CDIS data downloaded!")
        return pd.DataFrame()

    if failed_chunks:
        # The result below is cached, so make the partial coverage visible.
        print(f"  WARNING: {failed_chunks} chunk request(s) failed; "
              f"cached CDIS data will be incomplete.")

    df = pd.concat(all_data, ignore_index=True)
    print(f"\n  Raw CDIS: {len(df):,} rows")

    # Clean columns (same pattern as CPIS): map the known source names
    # (case-insensitive) onto the names used by the rest of the pipeline.
    aliases = {
        'ref_area': 'reporter', 'country': 'reporter',
        'counterpart_area': 'partner', 'counterpart_country': 'partner',
        'time_period': 'year',
        'obs_value': 'value', 'value': 'value',
    }
    cols_map = {c: aliases[c.lower()] for c in df.columns
                if c.lower() in aliases}
    df = df.rename(columns=cols_map)

    # All four columns are used below; without this guard a differently
    # named column surfaced as a bare KeyError.
    required = ['reporter', 'partner', 'year', 'value']
    missing = [c for c in required if c not in df.columns]
    if missing:
        print(f"  WARNING: Missing columns. Available: {list(df.columns)}")
        print(f"  WARNING: Could not identify: {missing}")
        # Keep the download for inspection, but not under the cache name:
        # a malformed cache file would be reloaded on every later run.
        unparsed = RAW_DIR / "cdis_bilateral_unparsed.csv"
        df.to_csv(unparsed, index=False)
        print(f"  Raw download kept at {unparsed}; treating CDIS as unavailable.")
        return pd.DataFrame()

    df = df[required].copy()
    df['year'] = pd.to_numeric(df['year'], errors='coerce')
    df['value'] = pd.to_numeric(df['value'], errors='coerce')
    df['indicator_label'] = 'fdi_outward'

    # Filter to country pairs. The length filter runs first so that missing
    # partner codes are gone before the prefix mask is negated.
    df = df[df['partner'].str.len() == 3].copy()
    df = df[~df['partner'].str.startswith(aggregate_prefixes)].copy()
    # Remove self-investments
    df = df[df['reporter'] != df['partner']].copy()

    df.to_csv(cache, index=False)
    print(f"  Saved CDIS: {len(df):,} rows to {cache}")
    return df


# -----------------------------------------------------------------------
# 1c. Download CEPII GeoDist gravity variables
# -----------------------------------------------------------------------

def download_cepii_geodist():
    """Download CEPII GeoDist gravity variables.

    If ``RAW_DIR / "cepii_geodist.csv"`` already exists it is loaded and
    returned. Otherwise the ZIP distribution (containing an XLS file) is
    tried first, then a plain CSV. The result is reduced to the pair
    identifiers plus the distance / contiguity / language / colonial
    columns that are present, renamed to descriptive names, and cached.

    When nothing can be downloaded, or the downloaded table lacks the pair
    identifiers or any distance column, the function falls back to
    ``_construct_minimal_gravity()``.

    Returns:
        pandas.DataFrame: One row per (``iso_o``, ``iso_d``) pair with a
        non-missing numeric ``dist_weighted`` and whichever of the other
        gravity columns were available.
    """
    cache = RAW_DIR / "cepii_geodist.csv"
    if cache.exists():
        print(f"  CEPII cached: {cache}")
        df = pd.read_csv(cache)
        print(f"  Loaded {len(df):,} rows")
        return df

    import zipfile

    print("  Downloading CEPII GeoDist data...")

    # CEPII distributes via ZIP (containing XLS). HTTPS is tried before
    # plain HTTP so the unauthenticated transport is only a fallback.
    urls_zip = [
        "https://www.cepii.fr/distance/dist_cepii.zip",
        "http://www.cepii.fr/distance/dist_cepii.zip",
    ]
    urls_csv = [
        "https://www.cepii.fr/distance/dist_cepii.csv",
        "http://www.cepii.fr/distance/dist_cepii.csv",
    ]

    df = None
    # Try ZIP first (more reliable)
    for url in urls_zip:
        try:
            print(f"    Trying ZIP: {url}")
            resp = requests.get(url, timeout=30)
            # The size check rejects tiny bodies such as short error pages.
            if resp.status_code == 200 and len(resp.content) > 1000:
                zip_path = RAW_DIR / "dist_cepii.zip"
                with open(zip_path, 'wb') as f:
                    f.write(resp.content)
                with zipfile.ZipFile(zip_path) as z:
                    z.extractall(RAW_DIR)
                xls_path = RAW_DIR / "dist_cepii.xls"
                if xls_path.exists():
                    df = pd.read_excel(xls_path)
                    print(f"    Success (XLS): {len(df):,} rows")
                # A valid archive was received; do not re-download it from
                # the other URL even if the XLS member was not found.
                break
        except Exception as e:
            print(f"    Failed: {e}")

    # Fallback: try CSV
    if df is None:
        for url in urls_csv:
            try:
                print(f"    Trying CSV: {url}")
                resp = requests.get(url, timeout=30)
                if resp.status_code == 200:
                    df = pd.read_csv(io.StringIO(resp.text))
                    print(f"    Success: {len(df):,} rows")
                    break
            except Exception as e:
                print(f"    Failed: {e}")

    if df is None:
        print("  CEPII download failed. Constructing minimal gravity vars from scratch.")
        return _construct_minimal_gravity()

    # A 200 response can still be something other than the expected table
    # (e.g. an HTML page that happens to parse); require the pair keys.
    missing_ids = [c for c in ('iso_o', 'iso_d') if c not in df.columns]
    if missing_ids:
        print(f"  WARNING: CEPII download lacks columns {missing_ids}. "
              f"Available: {list(df.columns)}")
        print("  Constructing minimal gravity vars from scratch.")
        return _construct_minimal_gravity()

    # Standardize column names
    rename = {
        'distw': 'dist_weighted',
        'dist': 'dist_simple',
        'contig': 'contiguity',
        'comlang_off': 'common_lang_official',
        'comlang_ethno': 'common_lang_ethno',
        'colony': 'colonial_ties',
        'comcol': 'common_colonizer',
    }
    df = df.rename(columns={k: v for k, v in rename.items() if k in df.columns})

    # Keep relevant columns
    optional = ['dist_weighted', 'dist_simple', 'contiguity',
                'common_lang_official', 'common_lang_ethno',
                'colonial_ties', 'common_colonizer']
    keep = ['iso_o', 'iso_d'] + [c for c in optional if c in df.columns]
    df = df[keep].copy()

    # Ensure numeric distance (CEPII XLS uses "." for missing)
    for col in ('dist_weighted', 'dist_simple'):
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

    # Use dist_simple as fallback for dist_weighted, both row-wise (missing
    # values) and when the weighted column is absent altogether.
    if 'dist_simple' in df.columns:
        if 'dist_weighted' in df.columns:
            df['dist_weighted'] = df['dist_weighted'].fillna(df['dist_simple'])
        else:
            df['dist_weighted'] = df['dist_simple']

    if 'dist_weighted' not in df.columns:
        print("  WARNING: CEPII download has no distance column. "
              "Constructing minimal gravity vars from scratch.")
        return _construct_minimal_gravity()

    df = df.dropna(subset=['dist_weighted'])

    df.to_csv(cache, index=False)
    print(f"  Saved CEPII: {len(df):,} rows to {cache}")
    return df


def _construct_minimal_gravity():
    """Fallback: construct minimal gravity vars from country coordinates.

    Reads the ``iso3`` column of the follow-up panel, keeps the countries
    that appear in the hard-coded coordinate table below, and builds one row
    per ordered pair (origin != destination) with the great-circle
    (haversine) distance in km stored as ``dist_weighted``.

    NOTE: contiguity, language and colonial indicators are all set to 0 as
    placeholders; they are not derived from any data. The result is written
    to ``RAW_DIR / "cepii_geodist.csv"``, the same cache file used for the
    real CEPII download, so later runs will load it from there.

    Returns:
        pandas.DataFrame: Columns ``iso_o``, ``iso_d``, ``dist_weighted``,
        ``contiguity``, ``common_lang_official``, ``common_lang_ethno``,
        ``colonial_ties``, ``common_colonizer``. Empty (but with these
        columns) when fewer than two panel countries have coordinates.
    """
    print("  Constructing minimal gravity variables from coordinates...")

    # Load our country panel to get ISO3 codes. Missing codes are dropped
    # first: sorting a mix of strings and NaN raises a TypeError.
    panel = pd.read_csv(FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv",
                        usecols=['iso3'])
    countries = sorted(panel['iso3'].dropna().unique())

    # Capital city approximate coordinates (lat, lon) for major countries
    coords = {
        'USA': (38.9, -77.0), 'GBR': (51.5, -0.1), 'DEU': (52.5, 13.4),
        'FRA': (48.9, 2.3), 'JPN': (35.7, 139.7), 'CHN': (39.9, 116.4),
        'IND': (28.6, 77.2), 'BRA': (-15.8, -47.9), 'AUS': (-35.3, 149.1),
        'CAN': (45.4, -75.7), 'KOR': (37.6, 127.0), 'MEX': (19.4, -99.1),
        'IDN': (-6.2, 106.8), 'RUS': (55.8, 37.6), 'ZAF': (-25.7, 28.2),
        'SAU': (24.7, 46.7), 'NGA': (9.1, 7.5), 'TUR': (39.9, 32.9),
        'ARG': (-34.6, -58.4), 'ITA': (41.9, 12.5), 'ESP': (40.4, -3.7),
        'NLD': (52.4, 4.9), 'CHE': (46.9, 7.4), 'SWE': (59.3, 18.1),
        'NOR': (59.9, 10.7), 'SGP': (1.3, 103.8), 'HKG': (22.3, 114.2),
        'TWN': (25.0, 121.5), 'THA': (13.8, 100.5), 'MYS': (3.1, 101.7),
        'PHL': (14.6, 121.0), 'VNM': (21.0, 105.8), 'EGY': (30.0, 31.2),
        'ISR': (31.8, 35.2), 'ARE': (24.5, 54.7), 'KWT': (29.4, 47.9),
        'QAT': (25.3, 51.5), 'CHL': (-33.4, -70.7), 'COL': (4.7, -74.1),
        'PER': (-12.0, -77.0), 'POL': (52.2, 21.0), 'CZE': (50.1, 14.4),
    }

    columns = ['iso_o', 'iso_d', 'dist_weighted', 'contiguity',
               'common_lang_official', 'common_lang_ethno',
               'colonial_ties', 'common_colonizer']

    rows = []
    country_list = [c for c in countries if c in coords]
    for iso_o in country_list:
        lat1, lon1 = coords[iso_o]
        for iso_d in country_list:
            if iso_o == iso_d:
                continue
            lat2, lon2 = coords[iso_d]
            # Haversine distance in km
            dlat = np.radians(lat2 - lat1)
            dlon = np.radians(lon2 - lon1)
            a = (np.sin(dlat / 2) ** 2 +
                 np.cos(np.radians(lat1)) * np.cos(np.radians(lat2)) *
                 np.sin(dlon / 2) ** 2)
            # Clamp before sqrt/arcsin: rounding can push `a` marginally
            # outside [0, 1], which would yield NaN.
            dist = 2 * 6371 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
            rows.append({
                'iso_o': iso_o, 'iso_d': iso_d,
                'dist_weighted': dist,
                'contiguity': 0,
                'common_lang_official': 0,
                'common_lang_ethno': 0,
                'colonial_ties': 0,
                'common_colonizer': 0,
            })

    # Explicit columns keep a header in the CSV even with zero pairs; a
    # header-less file could not be read back by pd.read_csv on later runs.
    df = pd.DataFrame(rows, columns=columns)
    cache = RAW_DIR / "cepii_geodist.csv"
    df.to_csv(cache, index=False)
    print(f"  Saved minimal gravity: {len(df):,} pairs")
    return df


# -----------------------------------------------------------------------
# 1d. Merge and construct bilateral panel
# -----------------------------------------------------------------------

def construct_bilateral_panel(cpis_df, cdis_df, gravity_df):
    """Merge all data sources into a bilateral panel.

    Steps: pivot CPIS to one column per indicator, outer-merge CDIS, attach
    gravity variables (dropping pairs without a distance), attach the
    follow-up demographics once for the reporter (``*_i``) and once for the
    partner (``*_j``), derive the bilateral regressors, print a summary and
    write the result to ``PROCESSED_DIR / "bilateral_panel.csv"``.

    Args:
        cpis_df: Long-format CPIS frame with columns ``reporter``,
            ``partner``, ``year``, ``value``, ``indicator_label``. ``None``,
            empty, or lacking any of these columns means "not available".
        cdis_df: Long-format CDIS frame with columns ``reporter``,
            ``partner``, ``year``, ``value``. ``None``, empty, or lacking any
            of these columns means "not available".
        gravity_df: Pair-level frame keyed by ``iso_o`` / ``iso_d`` with a
            ``dist_weighted`` column. Skipped when ``None`` or empty.

    Returns:
        pandas.DataFrame | None: The bilateral panel, or ``None`` when
        neither CPIS nor CDIS data is available.
    """
    print("\n" + "=" * 70)
    print("CONSTRUCTING BILATERAL PANEL")
    print("=" * 70)

    pair_keys = ['reporter', 'partner', 'year']
    flow_names = ['portfolio_total', 'portfolio_equity', 'portfolio_debt', 'fdi_outward']

    def _usable(frame, required):
        """Return True when *frame* is non-empty and has all *required* columns."""
        if frame is None or len(frame) == 0:
            return False
        missing = [c for c in required if c not in frame.columns]
        if missing:
            print(f"  WARNING: input frame lacks columns {missing}; ignoring it")
            return False
        return True

    # --- Load demographics ---
    panel = pd.read_csv(FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv")
    panel = panel[panel['year'] <= 2024].copy()
    print(f"  Demographics panel: {panel['iso3'].nunique()} countries, "
          f"{panel['year'].min()}-{panel['year'].max()}")

    # --- Pivot CPIS to wide format ---
    if _usable(cpis_df, pair_keys + ['value', 'indicator_label']):
        cpis_wide = cpis_df.pivot_table(
            index=pair_keys,
            columns='indicator_label',
            values='value',
            aggfunc='first'
        ).reset_index()
        cpis_wide.columns.name = None
        print(f"  CPIS wide: {len(cpis_wide):,} bilateral-year obs")
    else:
        cpis_wide = pd.DataFrame(columns=pair_keys)
        print("  CPIS: no data available")

    # --- Merge CDIS ---
    if _usable(cdis_df, pair_keys + ['value']):
        cdis_clean = cdis_df[pair_keys + ['value']].copy()
        cdis_clean = cdis_clean.rename(columns={'value': 'fdi_outward'})
        cdis_clean = cdis_clean.drop_duplicates(subset=pair_keys)
        print(f"  CDIS: {len(cdis_clean):,} bilateral-year obs")
    else:
        cdis_clean = pd.DataFrame(columns=pair_keys + ['fdi_outward'])
        print("  CDIS: no data available")

    # --- Combine CPIS and CDIS ---
    if len(cpis_wide) > 0 and len(cdis_clean) > 0:
        bilateral = cpis_wide.merge(cdis_clean, on=pair_keys, how='outer')
    elif len(cpis_wide) > 0:
        bilateral = cpis_wide.copy()
    elif len(cdis_clean) > 0:
        bilateral = cdis_clean.copy()
    else:
        print("  ERROR: No bilateral flow data available!")
        return None

    print(f"  Combined bilateral: {len(bilateral):,} obs, "
          f"{bilateral['reporter'].nunique()} reporters, "
          f"{bilateral['partner'].nunique()} partners")

    # --- Merge gravity variables ---
    if gravity_df is not None and len(gravity_df) > 0:
        bilateral = bilateral.merge(
            gravity_df,
            left_on=['reporter', 'partner'],
            right_on=['iso_o', 'iso_d'],
            how='left'
        )
        grav_coverage = bilateral['dist_weighted'].notna().mean()
        print(f"  Gravity merge coverage: {grav_coverage:.1%}")

        # Drop pairs without gravity data (no distance = can't estimate gravity)
        n_before = len(bilateral)
        bilateral = bilateral.dropna(subset=['dist_weighted'])
        print(f"  Dropped {n_before - len(bilateral):,} obs without gravity vars "
              f"({len(bilateral):,} remaining)")

    # --- Merge demographics for reporter (origin) and partner (destination) ---
    demo_cols = ['iso3', 'year', 'Z_1', 'Z_2', 'Z_3', 'kaopen',
                 'ngdp_usd', 'ca_gdp', 'nfa_gdp', 'fiscal_bal_gdp',
                 'expected_growth', 'nfa_gdp_lag', 'log_rel_opw',
                 'real_bond_10y_diff', 'log_lending_rate',
                 'life_expectancy', 'gdp_pc_ppp']
    demo_cols = [c for c in demo_cols if c in panel.columns]
    demo_value_cols = [c for c in demo_cols if c not in ('iso3', 'year')]

    # The same country-year table is attached twice; only the join key and
    # the suffix on the value columns differ between the two sides.
    for side, suffix in (('reporter', '_i'), ('partner', '_j')):
        side_demo = panel[demo_cols].rename(
            columns={c: f'{c}{suffix}' for c in demo_value_cols})
        bilateral = bilateral.merge(
            side_demo,
            left_on=[side, 'year'],
            right_on=['iso3', 'year'],
            how='left'
        ).drop(columns=['iso3'], errors='ignore')

    # --- Ensure numeric flow columns ---
    for fc in flow_names:
        if fc in bilateral.columns:
            bilateral[fc] = pd.to_numeric(bilateral[fc], errors='coerce')
    for nc in bilateral.columns:
        if nc.endswith(('_i', '_j')):
            bilateral[nc] = pd.to_numeric(bilateral[nc], errors='coerce')

    # --- Construct bilateral variables ---

    # Bilateral demographic distance: dZ_k = Z_k_i - Z_k_j
    for k in [1, 2, 3]:
        zi = f'Z_{k}_i'
        zj = f'Z_{k}_j'
        if zi in bilateral.columns and zj in bilateral.columns:
            bilateral[f'dZ_{k}'] = bilateral[zi] - bilateral[zj]

    # Interactions: dZ_k x KAOPEN_j (destination openness)
    for k in [1, 2, 3]:
        dz = f'dZ_{k}'
        if dz in bilateral.columns and 'kaopen_j' in bilateral.columns:
            bilateral[f'dZ_{k}_x_kaopen_j'] = bilateral[dz] * bilateral['kaopen_j']

    # Log distance (distances below 1 are floored at 1, i.e. log = 0)
    if 'dist_weighted' in bilateral.columns:
        bilateral['log_dist'] = np.log(bilateral['dist_weighted'].clip(lower=1))

    # Log GDP product: log(GDP_i x GDP_j), with the product floored at 1
    if 'ngdp_usd_i' in bilateral.columns and 'ngdp_usd_j' in bilateral.columns:
        gdp_prod = bilateral['ngdp_usd_i'] * bilateral['ngdp_usd_j']
        bilateral['log_gdp_product'] = np.log(gdp_prod.clip(lower=1))

    # Scale flows: log(flow) for intensive margin
    flow_cols = []
    for fc in flow_names:
        if fc in bilateral.columns:
            flow_cols.append(fc)
            # Negative and zero values have no log: they become NaN.
            positive = bilateral[fc].clip(lower=0)
            bilateral[f'log_{fc}'] = np.log(positive.replace(0, np.nan))
            # Flow / GDP_i ratio (in percent); zero GDP yields NaN, not inf
            if 'ngdp_usd_i' in bilateral.columns:
                gdp_i = bilateral['ngdp_usd_i'].replace(0, np.nan)
                bilateral[f'{fc}_gdp_i'] = bilateral[fc] / gdp_i * 100

    # Extensive margin indicators (position exists).
    # NOTE: a missing value compares as not > 0, so it is coded 0 here.
    for fc in flow_cols:
        bilateral[f'has_{fc}'] = (bilateral[fc] > 0).astype(int)

    # Bilateral interest rate differential
    if 'real_bond_10y_diff_i' in bilateral.columns and 'real_bond_10y_diff_j' in bilateral.columns:
        bilateral['rate_diff_ij'] = (bilateral['real_bond_10y_diff_i'] -
                                     bilateral['real_bond_10y_diff_j'])

    # Pair identifier for panel estimation
    bilateral['pair_id'] = bilateral['reporter'] + '_' + bilateral['partner']

    # --- Summary ---
    n_obs = len(bilateral)
    print(f"\n  Final bilateral panel:")
    print(f"    Observations: {n_obs:,}")
    print(f"    Unique pairs: {bilateral['pair_id'].nunique():,}")
    print(f"    Reporters: {bilateral['reporter'].nunique()}")
    print(f"    Partners: {bilateral['partner'].nunique()}")
    print(f"    Years: {bilateral['year'].min():.0f}-{bilateral['year'].max():.0f}")

    for fc in flow_cols:
        n_pos = (bilateral[fc] > 0).sum()
        share = n_pos / n_obs * 100 if n_obs else 0.0
        print(f"    {fc} > 0: {n_pos:,} ({share:.1f}%)")

    # The dZ_k columns only exist when the demographics panel carries the
    # matching Z_k columns, so report coverage over those actually built.
    dz_cols = [f'dZ_{k}' for k in (1, 2, 3) if f'dZ_{k}' in bilateral.columns]
    if dz_cols:
        demo_cov = bilateral[dz_cols].notna().all(axis=1).mean()
        print(f"    Demographic distance coverage: {demo_cov:.1%}")
    else:
        print("    Demographic distance coverage: n/a (no dZ_k columns constructed)")

    # Save
    outfile = PROCESSED_DIR / "bilateral_panel.csv"
    bilateral.to_csv(outfile, index=False)
    print(f"\n  Saved: {outfile}")

    return bilateral


# -----------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------

def main():
    """Run Phase 1: fetch the three inputs, build the panel, report coverage.

    Calls the CPIS, CDIS and CEPII loaders in turn, passes their results to
    ``construct_bilateral_panel`` and, when a panel comes back, writes a
    per-flow-type coverage table to ``OUTPUT_DIR / "bilateral_coverage.csv"``
    and prints it.

    Returns:
        Whatever ``construct_bilateral_panel`` returned (the panel, or
        ``None`` when no bilateral data was available).
    """
    banner = "=" * 70

    print(banner)
    print("PHASE 1: DOWNLOAD AND CONSTRUCT BILATERAL GRAVITY PANEL")
    print(banner)

    print("\n--- 1a. CPIS (Portfolio Investment) ---")
    cpis = download_cpis()

    print("\n--- 1b. CDIS (Direct Investment) ---")
    cdis = download_cdis()

    print("\n--- 1c. CEPII GeoDist Gravity Variables ---")
    gravity = download_cepii_geodist()

    print("\n--- 1d. Construct Bilateral Panel ---")
    bilateral = construct_bilateral_panel(cpis, cdis, gravity)

    # Nothing to summarise when panel construction gave up.
    if bilateral is None:
        return bilateral

    print("\n" + banner)
    print("PHASE 1 COMPLETE")
    print(banner)

    # Coverage summary: one row per flow column present in the panel,
    # counting only strictly positive, non-missing observations.
    flow_types = ('portfolio_total', 'portfolio_equity', 'portfolio_debt', 'fdi_outward')
    summary_rows = []
    for flow in flow_types:
        if flow not in bilateral.columns:
            continue
        series = bilateral[flow]
        mask = series.notna() & (series > 0)
        any_valid = mask.any()
        years = bilateral.loc[mask, 'year']
        summary_rows.append({
            'flow_type': flow,
            'n_obs': mask.sum(),
            'n_pairs': bilateral.loc[mask, 'pair_id'].nunique(),
            'n_reporters': bilateral.loc[mask, 'reporter'].nunique(),
            'n_partners': bilateral.loc[mask, 'partner'].nunique(),
            'year_min': years.min() if any_valid else np.nan,
            'year_max': years.max() if any_valid else np.nan,
        })

    summary = pd.DataFrame(summary_rows)
    summary.to_csv(OUTPUT_DIR / "bilateral_coverage.csv", index=False)
    print("\n  Coverage summary:")
    print(summary.to_string(index=False))

    return bilateral


# Run the pipeline only when this file is executed as a script (not when it
# is imported); the result of main() is bound to a module-level name.
if __name__ == "__main__":
    bilateral = main()
